{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import backend as K\n",
    "from tensorflow.keras import activations\n",
    "from tensorflow.keras.layers import Layer, Input, Embedding, LSTM, Dense, Attention\n",
    "from tensorflow.keras.models import Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Encoder(keras.Model):\n",
    "    def __init__(self, vocab_size, embedding_dim, hidden_units):\n",
    "        super(Encoder, self).__init__()\n",
    "        # Embedding Layer\n",
    "        self.embedding = Embedding(vocab_size, embedding_dim, mask_zero=True)\n",
    "        # Encode LSTM Layer\n",
    "        self.encoder_lstm = LSTM(hidden_units, return_sequences=True, return_state=True, name=\"encode_lstm\")\n",
    "        \n",
    "    def call(self, inputs):\n",
    "        encoder_embed = self.embedding(inputs)\n",
    "        encoder_outputs, state_h, state_c = self.encoder_lstm(encoder_embed)\n",
    "        return encoder_outputs, state_h, state_c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Decoder(keras.Model):\n",
    "    def __init__(self, vocab_size, embedding_dim, hidden_units):\n",
    "        super(Decoder, self).__init__()\n",
    "        # Embedding Layer\n",
    "        self.embedding = Embedding(vocab_size, embedding_dim, mask_zero=True)\n",
    "        # Decode LSTM Layer\n",
    "        self.decoder_lstm = LSTM(hidden_units, return_sequences=True, return_state=True, name=\"decode_lstm\")\n",
    "        # Attention Layer\n",
    "        self.attention = Attention()\n",
    "    \n",
    "    def call(self, enc_outputs, dec_inputs, states_inputs):\n",
    "        decoder_embed = self.embedding(dec_inputs)\n",
    "        dec_outputs, dec_state_h, dec_state_c = self.decoder_lstm(decoder_embed, initial_state=states_inputs)\n",
    "        attention_output = self.attention([dec_outputs, enc_outputs])\n",
    "        \n",
    "        return attention_output, dec_state_h, dec_state_c"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Seq2Seq(maxlen, embedding_dim, hidden_units, vocab_size):\n",
    "    \"\"\"\n",
    "    seq2seq model\n",
    "    \"\"\"\n",
    "    # Input Layer\n",
    "    encoder_inputs = Input(shape=(maxlen,), name=\"encode_input\")\n",
    "    decoder_inputs = Input(shape=(None,), name=\"decode_input\")\n",
    "    # Encoder Layer\n",
    "    encoder = Encoder(vocab_size, embedding_dim, hidden_units)\n",
    "    enc_outputs, enc_state_h, enc_state_c = encoder(encoder_inputs)\n",
    "    dec_states_inputs = [enc_state_h, enc_state_c]\n",
    "    # Decoder Layer\n",
    "    decoder = Decoder(vocab_size, embedding_dim, hidden_units)\n",
    "    attention_output, dec_state_h, dec_state_c = decoder(enc_outputs, decoder_inputs, dec_states_inputs)\n",
    "    # Dense Layer\n",
    "    dense_outputs = Dense(vocab_size, activation='softmax', name=\"dense\")(attention_output)\n",
    "    # seq2seq model\n",
    "    model = Model(inputs=[encoder_inputs, decoder_inputs], outputs=dense_outputs)\n",
    "    \n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_vocab(vocab_path):\n",
    "    vocab_words = []\n",
    "    with open(vocab_path, \"r\", encoding=\"utf8\") as f:\n",
    "        for line in f:\n",
    "            vocab_words.append(line.strip())\n",
    "    return vocab_words\n",
    "\n",
    "def read_data(data_path):\n",
    "    datas = []\n",
    "    with open(data_path, \"r\", encoding=\"utf8\") as f:\n",
    "        for line in f:\n",
    "            words = line.strip().split()\n",
    "            datas.append(words)\n",
    "    return datas\n",
    "\n",
    "def process_data_index(datas, vocab2id):\n",
    "    data_indexs = []\n",
    "    for words in datas:\n",
    "        line_index = [vocab2id[w] if w in vocab2id else vocab2id[\"<UNK>\"] for w in words]\n",
    "        data_indexs.append(line_index)\n",
    "    return data_indexs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocab test:  ['<PAD>', '<UNK>', '<GO>', '<EOS>', '呵呵', '不是', '怎么', '了', '开心', '点']\n",
      "source test:  ['许兵', '是', '谁']\n",
      "source index:  [26, 27, 24]\n",
      "target test:  ['是', '我', '善良', '可爱', '的', '主人', '的', '老公', '啊']\n",
      "target index:  [27, 16, 9572, 436, 45, 452, 45, 274, 111]\n"
     ]
    }
   ],
   "source": [
    "vocab_words = read_vocab(\"data/ch_word_vocab.txt\")\n",
    "special_words = [\"<PAD>\", \"<UNK>\", \"<GO>\", \"<EOS>\"]\n",
    "vocab_words = special_words + vocab_words\n",
    "vocab2id = {word: i for i, word in enumerate(vocab_words)}\n",
    "id2vocab = {i: word for i, word in enumerate(vocab_words)}\n",
    "\n",
    "num_sample = 1000\n",
    "source_data = read_data(\"data/ch_source_data_seg.txt\")[:num_sample]\n",
    "source_data_ids = process_data_index(source_data, vocab2id)\n",
    "target_data = read_data(\"data/ch_target_data_seg.txt\")[:num_sample]\n",
    "target_data_ids = process_data_index(target_data, vocab2id)\n",
    "\n",
    "print(\"vocab test: \", [id2vocab[i] for i in range(10)])\n",
    "print(\"source test: \", source_data[10])\n",
    "print(\"source index: \", source_data_ids[10])\n",
    "print(\"target test: \", target_data[10])\n",
    "print(\"target index: \", target_data_ids[10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "encoder inputs:  [[2, 4, 3], [2, 5, 3]]\n",
      "decoder inputs:  [[2, 27, 37846, 756, 45, 180], [2, 38, 27, 84, 49272]]\n",
      "decoder outputs:  [[27, 37846, 756, 45, 180, 3], [38, 27, 84, 49272, 3]]\n"
     ]
    }
   ],
   "source": [
    "def process_input_data(source_data_ids, target_indexs, vocab2id):\n",
    "    source_inputs = []\n",
    "    decoder_inputs, decoder_outputs = [], []\n",
    "    for source, target in zip(source_data_ids, target_indexs):\n",
    "        source_inputs.append([vocab2id[\"<GO>\"]] + source + [vocab2id[\"<EOS>\"]])\n",
    "        decoder_inputs.append([vocab2id[\"<GO>\"]] + target)\n",
    "        decoder_outputs.append(target + [vocab2id[\"<EOS>\"]])\n",
    "    return source_inputs, decoder_inputs, decoder_outputs\n",
    "\n",
    "source_input_ids, target_input_ids, target_output_ids = process_input_data(source_data_ids, target_data_ids, vocab2id)\n",
    "print(\"encoder inputs: \", source_input_ids[:2])\n",
    "print(\"decoder inputs: \", target_input_ids[:2])\n",
    "print(\"decoder outputs: \", target_output_ids[:2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[4], [5], [6, 7], [8, 9, 10, 11, 12, 13, 14, 15], [16, 17, 18, 19, 11, 20]]\n",
      "[[    2    27 37846   756    45   180     0     0     0     0]\n",
      " [    2    38    27    84 49272     0     0     0     0     0]\n",
      " [    2    16  6692    82 49273   320    16   518     0     0]\n",
      " [    2   526     0     0     0     0     0     0     0     0]\n",
      " [   16   438    22   328    19 49272 15817   254  1764 49272]]\n",
      "[[   27 37846   756    45   180     3     0     0     0     0]\n",
      " [   38    27    84 49272     3     0     0     0     0     0]\n",
      " [   16  6692    82 49273   320    16   518     3     0     0]\n",
      " [  526     3     0     0     0     0     0     0     0     0]\n",
      " [  438    22   328    19 49272 15817   254  1764 49272     3]]\n"
     ]
    }
   ],
   "source": [
    "maxlen = 10\n",
    "source_input_ids = keras.preprocessing.sequence.pad_sequences(source_input_ids, padding='post', maxlen=maxlen)\n",
    "target_input_ids = keras.preprocessing.sequence.pad_sequences(target_input_ids, padding='post',  maxlen=maxlen)\n",
    "target_output_ids = keras.preprocessing.sequence.pad_sequences(target_output_ids, padding='post',  maxlen=maxlen)\n",
    "print(source_data_ids[:5])\n",
    "print(target_input_ids[:5])\n",
    "print(target_output_ids[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"model\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "encode_input (InputLayer)       [(None, 10)]         0                                            \n",
      "__________________________________________________________________________________________________\n",
      "encoder (Encoder)               ((None, 10, 128), (N 3598348     encode_input[0][0]               \n",
      "__________________________________________________________________________________________________\n",
      "decode_input (InputLayer)       [(None, None)]       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "decoder (Decoder)               ((None, None, 128),  3598348     encoder[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense (Dense)                   (None, None, 70134)  9047286     decoder[0][0]                    \n",
      "==================================================================================================\n",
      "Total params: 16,243,982\n",
      "Trainable params: 16,243,982\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "K.clear_session()\n",
    "\n",
    "maxlen = 10\n",
    "embedding_dim = 50\n",
    "hidden_units = 128\n",
    "vocab_size = len(vocab2id)\n",
    "\n",
    "model = Seq2Seq(maxlen, embedding_dim, hidden_units, vocab_size)\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 800 samples, validate on 200 samples\n",
      "Epoch 1/20\n",
      "800/800 [==============================] - 27s 34ms/sample - loss: 11.1144 - val_loss: 10.9350\n",
      "Epoch 2/20\n",
      "800/800 [==============================] - 21s 27ms/sample - loss: 9.5005 - val_loss: 8.4937\n",
      "Epoch 3/20\n",
      "800/800 [==============================] - 23s 29ms/sample - loss: 7.9150 - val_loss: 8.5811\n",
      "Epoch 4/20\n",
      "800/800 [==============================] - 23s 29ms/sample - loss: 7.7800 - val_loss: 8.6219\n",
      "Epoch 5/20\n",
      "800/800 [==============================] - 24s 30ms/sample - loss: 7.7378 - val_loss: 8.6402\n",
      "Epoch 6/20\n",
      "800/800 [==============================] - 24s 30ms/sample - loss: 7.7115 - val_loss: 8.6691\n",
      "Epoch 7/20\n",
      "800/800 [==============================] - 24s 31ms/sample - loss: 7.6915 - val_loss: 8.6835\n",
      "Epoch 8/20\n",
      "800/800 [==============================] - 24s 31ms/sample - loss: 7.6670 - val_loss: 8.7008\n",
      "Epoch 9/20\n",
      "800/800 [==============================] - 24s 31ms/sample - loss: 7.6509 - val_loss: 8.6933\n",
      "Epoch 10/20\n",
      "800/800 [==============================] - 25s 31ms/sample - loss: 7.6304 - val_loss: 8.6807\n",
      "Epoch 11/20\n",
      "800/800 [==============================] - 24s 31ms/sample - loss: 7.6100 - val_loss: 8.6684\n",
      "Epoch 12/20\n",
      "800/800 [==============================] - 25s 31ms/sample - loss: 7.5927 - val_loss: 8.6545\n",
      "Epoch 13/20\n",
      "800/800 [==============================] - 24s 30ms/sample - loss: 7.5750 - val_loss: 8.6396\n",
      "Epoch 14/20\n",
      "800/800 [==============================] - 25s 31ms/sample - loss: 7.5568 - val_loss: 8.6251\n",
      "Epoch 15/20\n",
      "800/800 [==============================] - 25s 31ms/sample - loss: 7.5348 - val_loss: 8.6122\n",
      "Epoch 16/20\n",
      "800/800 [==============================] - 25s 31ms/sample - loss: 7.5171 - val_loss: 8.5965\n",
      "Epoch 17/20\n",
      "800/800 [==============================] - 25s 31ms/sample - loss: 7.4975 - val_loss: 8.5799\n",
      "Epoch 18/20\n",
      "800/800 [==============================] - 24s 30ms/sample - loss: 7.4786 - val_loss: 8.5694\n",
      "Epoch 19/20\n",
      "800/800 [==============================] - 25s 31ms/sample - loss: 7.4577 - val_loss: 8.5504\n",
      "Epoch 20/20\n",
      "800/800 [==============================] - 24s 31ms/sample - loss: 7.4396 - val_loss: 8.5273\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x19bddc78a90>"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "epochs = 20\n",
    "batch_size = 32\n",
    "val_rate = 0.2\n",
    "\n",
    "loss_fn = keras.losses.SparseCategoricalCrossentropy()\n",
    "model.compile(loss=loss_fn, optimizer='adam')\n",
    "model.fit([source_input_ids, target_input_ids], target_output_ids, \n",
    "          batch_size=batch_size, epochs=epochs, validation_split=val_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_weights(\"data/seq2seq_attention_weights.h5\")\n",
    "del model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"model\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "encode_input (InputLayer)       [(None, 10)]         0                                            \n",
      "__________________________________________________________________________________________________\n",
      "encoder (Encoder)               ((None, 10, 128), (N 3598348     encode_input[0][0]               \n",
      "__________________________________________________________________________________________________\n",
      "decode_input (InputLayer)       [(None, None)]       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "decoder (Decoder)               ((None, None, 128),  3598348     encoder[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "dense (Dense)                   (None, None, 70134)  9047286     decoder[0][0]                    \n",
      "==================================================================================================\n",
      "Total params: 16,243,982\n",
      "Trainable params: 16,243,982\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "K.clear_session()\n",
    "\n",
    "model = Seq2Seq(maxlen, embedding_dim, hidden_units, vocab_size)\n",
    "model.load_weights(\"data/seq2seq_attention_weights.h5\")\n",
    "print(model.summary())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"model_1\"\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "encode_input (InputLayer)    [(None, 10)]              0         \n",
      "_________________________________________________________________\n",
      "encoder (Encoder)            ((None, 10, 128), (None,  3598348   \n",
      "=================================================================\n",
      "Total params: 3,598,348\n",
      "Trainable params: 3,598,348\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "def encoder_infer(model):\n",
    "    encoder_model = Model(inputs=model.get_layer('encoder').inputs, \n",
    "                        outputs=model.get_layer('encoder').outputs)\n",
    "    return encoder_model\n",
    "\n",
    "encoder_model = encoder_infer(model)\n",
    "print(encoder_model.summary())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"model_2\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "enc_output (InputLayer)         [(None, 10, 128)]    0                                            \n",
      "__________________________________________________________________________________________________\n",
      "decode_input (InputLayer)       [(None, None)]       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_state_h (InputLayer)      [(None, 128)]        0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_state_c (InputLayer)      [(None, 128)]        0                                            \n",
      "__________________________________________________________________________________________________\n",
      "decoder (Decoder)               ((None, None, 128),  3598348     enc_output[0][0]                 \n",
      "__________________________________________________________________________________________________\n",
      "dense (Dense)                   (None, None, 70134)  9047286     decoder[1][0]                    \n",
      "==================================================================================================\n",
      "Total params: 12,645,634\n",
      "Trainable params: 12,645,634\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "def decoder_infer(model, encoder_model):\n",
    "    encoder_output = encoder_model.get_layer('encoder').output[0]\n",
    "    maxlen, hidden_units = encoder_output.shape[1:]\n",
    "    \n",
    "    dec_input = model.get_layer('decode_input').input\n",
    "    enc_output = Input(shape=(maxlen, hidden_units), name='enc_output')\n",
    "    dec_input_state_h = Input(shape=(hidden_units,), name='input_state_h')\n",
    "    dec_input_state_c = Input(shape=(hidden_units,), name='input_state_c')\n",
    "    dec_input_states = [dec_input_state_h, dec_input_state_c]\n",
    "\n",
    "    decoder = model.get_layer('decoder')\n",
    "    dec_outputs, out_state_h, out_state_c = decoder(enc_output, dec_input, dec_input_states)\n",
    "    dec_output_states = [out_state_h, out_state_c]\n",
    "\n",
    "    decoder_dense = model.get_layer('dense')\n",
    "    dense_output = decoder_dense(dec_outputs)\n",
    "\n",
    "    decoder_model = Model(inputs=[enc_output, dec_input, dec_input_states], \n",
    "                          outputs=[dense_output]+dec_output_states)\n",
    "    return decoder_model\n",
    "\n",
    "decoder_model = decoder_infer(model, encoder_model)\n",
    "print(decoder_model.summary())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "maxlen = 10\n",
    "\n",
    "def infer_predict(input_text, encoder_model, decoder_model):\n",
    "    text_words = input_text.split()[:maxlen]\n",
    "    input_id = [vocab2id[w] if w in vocab2id else vocab2id[\"<UNK>\"] for w in text_words]\n",
    "    input_id = [vocab2id[\"<GO>\"]] + input_id + [vocab2id[\"<EOS>\"]]\n",
    "    if len(input_id) < maxlen:\n",
    "        input_id = input_id + [vocab2id[\"<PAD>\"]] * (maxlen-len(input_id))\n",
    "\n",
    "    input_source = np.array([input_id])\n",
    "    input_target = np.array([vocab2id[\"<GO>\"]])\n",
    "    \n",
    "    # 编码器encoder预测输出\n",
    "    enc_outputs, enc_state_h, enc_state_c = encoder_model.predict([input_source])\n",
    "    dec_inputs = input_target\n",
    "    dec_states_inputs = [enc_state_h, enc_state_c]\n",
    "\n",
    "    result_id = []\n",
    "    result_text = []\n",
    "    for i in range(maxlen):\n",
    "        # 解码器decoder预测输出\n",
    "        dense_outputs, dec_state_h, dec_state_c = decoder_model.predict([enc_outputs, dec_inputs]+dec_states_inputs)\n",
    "        pred_id = np.argmax(dense_outputs[0][0])\n",
    "        result_id.append(pred_id)\n",
    "        result_text.append(id2vocab[pred_id])\n",
    "        if id2vocab[pred_id] == \"<EOS>\":\n",
    "            break\n",
    "        dec_inputs = np.array([[pred_id]])\n",
    "        dec_states_inputs = [dec_state_h, dec_state_c]\n",
    "    return result_id, result_text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input:  你 在 干 什么 呢\n",
      "Output:  ['<EOS>'] [3]\n"
     ]
    }
   ],
   "source": [
    "input_text = \"你 在 干 什么 呢\"\n",
    "result_id, result_text = infer_predict(input_text, encoder_model, decoder_model)\n",
    "\n",
    "print(\"Input: \", input_text)\n",
    "print(\"Output: \", result_text, result_id)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
